# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Bounding Box List definition.

BoxList represents a list of bounding boxes as tensorflow
tensors, where each bounding box is represented as a row of 4 numbers,
[y_min, x_min, y_max, x_max].  It is assumed that all bounding boxes
within a given list correspond to a single image.  See also
box_list_ops.py for common box related operations (such as area, iou, etc).

Optionally, users can add additional related fields (such as weights).
We assume the following things to be true about fields:
* they correspond to boxes in the box_list along the 0th dimension
* they have inferrable rank at graph construction time
* all dimensions except for possibly the 0th can be inferred
  (i.e., not None) at graph construction time.

Some other notes:
  * Following tensorflow conventions, we use height, width ordering,
  and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering
  * Tensors are always provided as (flat) [N, 4] tensors.
"""

import tensorflow as tf


class BoxList(object):
    """Collection of bounding boxes plus optional per-box fields.

    Every tensor is kept in the `self.data` dictionary. Box corners are stored
    under the key 'boxes'; additional fields (weights, labels, ...) are stored
    under caller-chosen keys.
    """

    def __init__(self, boxes):
        """Constructs box collection.

        Args:
          boxes: a tensor of shape [N, 4] representing box corners

        Raises:
          ValueError: if invalid dimensions for bbox data or if bbox data is
            not in float32 format.
        """
        box_shape = boxes.get_shape()
        if len(box_shape) != 2 or box_shape[-1] != 4:
            raise ValueError('Invalid dimensions for box data.')
        if boxes.dtype != tf.float32:
            raise ValueError('Invalid tensor type: should be tf.float32')
        self.data = {'boxes': boxes}

    def num_boxes(self):
        """Returns number of boxes held in collection.

        Returns:
          a tensor representing the number of boxes held in the collection.
        """
        return tf.shape(self.data['boxes'])[0]

    def num_boxes_static(self):
        """Returns number of boxes held in collection.

        This number is inferred at graph construction time rather than run-time.

        Returns:
          Number of boxes held in collection (integer) or None if this is not
            inferrable at graph construction time.
        """
        leading_dim = self.data['boxes'].get_shape()[0]
        # Indexing a static shape yields a Dimension object (size stored in
        # `.value`) under TF1 behavior, but a plain int/None under TF2
        # behavior. Accept both instead of assuming `.value` exists.
        return getattr(leading_dim, 'value', leading_dim)

    def get_all_fields(self):
        """Returns the names of all fields, including 'boxes'.

        Returns:
          a new list of field names. It is a snapshot, so fields may be added
          to the BoxList while the caller iterates over the result.
        """
        return list(self.data)

    def get_extra_fields(self):
        """Returns all non-box fields (i.e., everything not named 'boxes')."""
        return [name for name in self.data if name != 'boxes']

    def add_field(self, field, field_data):
        """Add field to box list.

        This method can be used to add related box data such as
        weights/labels, etc. No validation is performed; if `field` already
        exists its previous value is replaced.

        Args:
          field: a string key to access the data via `get_field`
          field_data: a tensor containing the data to store in the BoxList
        """
        self.data[field] = field_data

    def has_field(self, field):
        """Returns True if `field` is stored in this box list.

        Args:
          field: key to look up (the box coordinates use the key 'boxes').
        """
        return field in self.data

    def get(self):
        """Convenience function for accessing box coordinates.

        Returns:
          a tensor with shape [N, 4] representing box coordinates.
        """
        return self.get_field('boxes')

    def set(self, boxes):
        """Convenience function for setting box coordinates.

        Only the shape is checked here; unlike the constructor, the dtype of
        `boxes` is not validated.

        Args:
          boxes: a tensor of shape [N, 4] representing box corners

        Raises:
          ValueError: if invalid dimensions for bbox data
        """
        box_shape = boxes.get_shape()
        if len(box_shape) != 2 or box_shape[-1] != 4:
            raise ValueError('Invalid dimensions for box data.')
        self.data['boxes'] = boxes

    def get_field(self, field):
        """Accesses the box coordinates or an associated field.

        Args:
          field: key of the field to return. Pass 'boxes' (or call `get`) to
            obtain the box coordinates.

        Returns:
          a tensor representing the box collection or an associated field.

        Raises:
          ValueError: if invalid field
        """
        if not self.has_field(field):
            # `.format` (rather than `%`) so that a tuple key cannot be
            # mistaken for a sequence of format arguments.
            raise ValueError('field {} does not exist'.format(field))
        return self.data[field]

    def set_field(self, field, value):
        """Sets the value of a field.

        Updates the field of a box_list with a given value. Use `add_field`
        to create a field that does not exist yet.

        Args:
          field: (string) name of the field to set value.
          value: the value to assign to the field.

        Raises:
          ValueError: if the box_list does not have specified field.
        """
        if not self.has_field(field):
            # `.format` keeps the documented ValueError even when `field` is
            # a tuple, which `%` would try to unpack and fail with TypeError.
            raise ValueError('field {} does not exist'.format(field))
        self.data[field] = value

    def get_center_coordinates_and_sizes(self, scope=None):
        """Computes the center coordinates, height and width of the boxes.

        Args:
          scope: name scope of the function.

        Returns:
          a list of 4 1-D tensors [ycenter, xcenter, height, width].
        """
        with tf.name_scope(scope, 'get_center_coordinates_and_sizes'):
            box_corners = self.get()
            # Transposing [N, 4] to [4, N] lets unstack yield one tensor per
            # coordinate.
            ymin, xmin, ymax, xmax = tf.unstack(tf.transpose(box_corners))
            width = xmax - xmin
            height = ymax - ymin
            ycenter = ymin + height / 2.
            xcenter = xmin + width / 2.
            return [ycenter, xcenter, height, width]

    def transpose_coordinates(self, scope=None):
        """Transpose the coordinate representation in a boxlist.

        Swaps the y and x columns in place: the stored boxes are replaced by
        the columns reordered as [x_min, y_min, x_max, y_max].

        Args:
          scope: name scope of the function.
        """
        with tf.name_scope(scope, 'transpose_coordinates'):
            y_min, x_min, y_max, x_max = tf.split(
                value=self.get(), num_or_size_splits=4, axis=1)
            self.set(tf.concat([x_min, y_min, x_max, y_max], 1))

    def as_tensor_dict(self, fields=None):
        """Retrieves specified fields as a dictionary of tensors.

        Args:
          fields: (optional) list of fields to return in the dictionary.
            If None (default), all fields are returned.

        Returns:
          tensor_dict: A dictionary of tensors specified by fields.

        Raises:
          ValueError: if specified field is not contained in boxlist.
        """
        if fields is None:
            fields = self.get_all_fields()
        tensor_dict = {}
        for field in fields:
            if not self.has_field(field):
                raise ValueError('boxlist must contain all specified fields')
            tensor_dict[field] = self.get_field(field)
        return tensor_dict
